#include "DAG.hpp"
#include <random>
#include <algorithm>
#include <iostream>
#include <fstream>
#include <set>
#include <queue>
#include <stack>

using namespace std;

/// @brief Records the requested graph dimensions; no adjacency data is built.
///
/// @param numNodes  Number of nodes, stored in `n`.
/// @param numEdges  Number of edges requested, stored in `m`.
///
/// MAXEDGE is initialised to numNodes * (numNodes - 1) / 2, i.e. the number of
/// index pairs (u, v) with u < v. The product is evaluated in `int`.
DirectedAcyclicGraph::DirectedAcyclicGraph(int numNodes, int numEdges)
    : n(numNodes),
      m(numEdges),
      MAXEDGE(numNodes * (numNodes - 1) / 2)
{
}

void DirectedAcyclicGraph::BuildCompleteDAG() {
    DAG.assign(n, vector<int>());
    DAG_R.assign(n, vector<int>());
    indegree.assign(n, 0);
    indegree_r.assign(n, 0);
    edges.clear();
    const int total_edges = n * (n - 1) / 2;
    vector<pair<int, int>> all_edges;
    all_edges.reserve(total_edges);
    int edge_count = 0;
    for (int u = 0; u < n; ++u) {
        for (int v = u + 1; v < n; ++v) {
            all_edges.emplace_back(u, v);
            edge_count++;
        }
    }
    random_device rd;
    mt19937 gen(rd());
    shuffle(all_edges.begin(), all_edges.end(), gen);
    for (int i = 0; i < m; ++i) {
        int u = all_edges[i].first;
        int v = all_edges[i].second;
        DAG[u].push_back(v);
        DAG_R[v].push_back(u);
        edges.emplace_back(u, v);
        indegree[v]++;
        indegree_r[u]++;
    }
}
/// @brief Splits the nodes into vertex-disjoint paths with a greedy walk.
///
/// Nodes are scanned in index order. Each node that has not been used yet
/// starts a new chain, which is extended by repeatedly stepping to the first
/// successor in DAG[current] that has not been used, until no such successor
/// exists.
///
/// @return One vector per chain, listing its nodes in walk order. Every node
///         in [0, n) appears in exactly one chain.
vector<vector<int>> DirectedAcyclicGraph::GreedyChainDecomposition() {
    vector<vector<int>> result;
    vector<bool> taken(n, false);

    for (int head = 0; head < n; ++head) {
        if (taken[head]) {
            continue;
        }

        // Build the chain directly inside the result container.
        result.emplace_back();
        vector<int>& path = result.back();

        int node = head;
        while (node != -1 && !taken[node]) {
            taken[node] = true;
            path.push_back(node);

            // First successor that is still free, or -1 to end the chain.
            const vector<int>& successors = DAG[node];
            const auto free_succ = find_if(successors.begin(), successors.end(),
                                           [&taken](int cand) { return !taken[cand]; });
            node = (free_succ == successors.end()) ? -1 : *free_succ;
        }
    }
    return result;
}

/// @brief Turns a predecessor table into an explicit list of chains.
///
/// @param match      Not read by this function.
/// @param inv_match  inv_match[v] is the node placed immediately before v in
///                   its chain, or -1 if v has no predecessor. Indexed for
///                   every v in [0, n).
/// @return One vector per chain. Chains that start at a node without a
///         predecessor come first (in order of their start index), followed by
///         a single-node chain for every node those walks did not reach.
vector<vector<int>> DirectedAcyclicGraph::ChainDecomposition(const vector<int>& match, const vector<int>& inv_match) {
    // Invert the predecessor table: successor[u] == v  <=>  inv_match[v] == u.
    vector<int> successor(n, -1);
    for (int node = 0; node < n; ++node) {
        const int pred = inv_match[node];
        if (pred != -1) {
            successor[pred] = node;
        }
    }

    vector<vector<int>> result;
    vector<bool> placed(n, false);

    // Walk forward from every node that has no predecessor.
    for (int head = 0; head < n; ++head) {
        const bool is_chain_head = !placed[head] && inv_match[head] == -1;
        if (!is_chain_head) {
            continue;
        }

        vector<int> chain;
        for (int node = head; node != -1 && !placed[node]; node = successor[node]) {
            chain.push_back(node);
            placed[node] = true;
        }
        if (!chain.empty()) {
            result.push_back(chain);
        }
    }

    // Whatever is still unplaced becomes its own one-element chain.
    for (int node = 0; node < n; ++node) {
        if (placed[node]) {
            continue;
        }
        result.push_back(vector<int>(1, node));
        placed[node] = true;
    }
    return result;
}


void DirectedAcyclicGraph::Reachable(int x, std::unordered_set<int>& left, std::unordered_set<int>& right, const std::vector<bool>& active) const {
    int n = active.size();
    vector<bool> visited(n, false);
    queue<int> q;
    visited[x] = true;
    q.push(x);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : DAG[u]){
            if (!visited[v]) {
                q.push(v);
                visited[v] = true;
                if(active[v]){
                    right.insert(v);
                }
            }
        }
    }
    visited.assign(n, false);
    visited[x] = true;
    q.push(x);
    while (!q.empty()) {
        int u = q.front();
        q.pop();
        for (int v : DAG_R[u]) {
            if (!visited[v]) {
                q.push(v);
                visited[v] = true;
                if(active[v]){
                    left.insert(v);
                }
            }
        }
    }
}
// Static counter of how many times custom_compare() has been invoked since it
// was last set to zero. Being static, it is shared by all graph instances.
int DirectedAcyclicGraph::comparison_count_sort = 0;

/// @brief Strict "less than" on two ints that also counts its invocations.
///
/// @return true when `a` is smaller than `b`.
///
/// Each call increments comparison_count_sort by one.
bool DirectedAcyclicGraph::custom_compare(const int& a, const int& b) {
    const bool a_before_b = a < b;
    comparison_count_sort += 1;
    return a_before_b;
}
/// @brief Selects "median" nodes from groups of five active nodes.
///
/// Active nodes are discovered with an iterative depth-first traversal along
/// forward edges (DAG), started from every unvisited active node in index
/// order. Nodes are collected in the order they are popped; whenever five have
/// been collected they are sorted by index with custom_compare and the middle
/// one is recorded as a median. Nodes left in a partial group when a traversal
/// ends go into a shared pool, which is afterwards consumed five at a time
/// from its back in the same way. If no median has been produced by then, the
/// middle element of the sorted pool is used instead.
///
/// @param active        active[i] marks the nodes that take part; only read.
///                      Indexed for every i in [0, GetNumNodes()).
/// @param local_active  Set to true at the index of every returned median;
///                      other entries are left untouched.
/// @return The medians in the order they were recorded. Empty when there is
///         no active node, in which case a warning is written to stderr.
///
/// Side effect: comparison_count_sort is reset on entry, so after the call it
/// holds the number of comparisons performed by all sorts of this call.
vector<int> DirectedAcyclicGraph::Median(vector<bool>& active, vector<bool>& local_active) const {
    // Reset before any sorting happens. The reset used to sit between the two
    // phases, which discarded the comparisons made while sorting the groups
    // found during the traversal.
    comparison_count_sort = 0;

    constexpr size_t kGroupSize = 5;
    const int node_count = GetNumNodes();

    vector<bool> visited(node_count, false);
    vector<int> medians;
    vector<int> remaining;

    // Sorts `grp` by node index (counting comparisons) and records its middle
    // element as a median. For a full group of five that is element 2.
    const auto record_median = [&medians, &local_active](vector<int>& grp) {
        sort(grp.begin(), grp.end(), custom_compare);
        const int median = grp[grp.size() / 2];
        medians.push_back(median);
        local_active[median] = true;
    };

    for (int i = 0; i < node_count; i++) {
        if (visited[i] || !active[i]) {
            continue;
        }

        vector<int> group;
        stack<int> s;
        s.push(i);
        visited[i] = true;

        while (!s.empty()) {
            const int u = s.top();
            s.pop();
            group.push_back(u);

            for (int v : DAG[u]) {
                if (!visited[v] && active[v]) {
                    s.push(v);
                    visited[v] = true;
                }
            }

            if (group.size() == kGroupSize) {
                record_median(group);
                group.clear();
            }
        }

        // Fewer than five nodes left over from this traversal: pool them.
        remaining.insert(remaining.end(), group.begin(), group.end());
    }

    // Form further groups of five from the back of the pool.
    while (remaining.size() >= kGroupSize) {
        const auto tail_begin = remaining.end() - static_cast<ptrdiff_t>(kGroupSize);
        vector<int> tail(tail_begin, remaining.end());
        remaining.erase(tail_begin, remaining.end());
        record_median(tail);
    }

    // Fewer than five active nodes overall: fall back to the pool's middle.
    if (!remaining.empty() && medians.empty()) {
        record_median(remaining);
    }

    if (medians.empty()) {
        cerr << "Warning: DirectedAcyclicGraph::Median found no active nodes.\n";
    }
    return medians;
}

/// Forward adjacency lists: GetDAG()[u] holds the targets of edges leaving u.
const vector<vector<int>>& DirectedAcyclicGraph::GetDAG() const
{
    return DAG;
}

/// Reverse adjacency lists: GetDAG_R()[v] holds the sources of edges entering v.
const vector<vector<int>>& DirectedAcyclicGraph::GetDAG_R() const
{
    return DAG_R;
}

/// Edge list as (source, target) pairs.
const vector<pair<int, int>>& DirectedAcyclicGraph::GetEdges() const
{
    return edges;
}

/// Per-node counters stored in `indegree` (BuildCompleteDAG counts incoming edges here).
const vector<int>& DirectedAcyclicGraph::GetIndegree() const
{
    return indegree;
}

/// Per-node counters stored in `indegree_r` (BuildCompleteDAG counts outgoing edges here).
const vector<int>& DirectedAcyclicGraph::GetIndegreeR() const
{
    return indegree_r;
}

/// Returns a copy of `indexTolabel` (returned by value, not by reference).
const vector<int> DirectedAcyclicGraph::GetindexTolabel() const
{
    return indexTolabel;
}

/// Value of MAXEDGE as set by the constructor.
int DirectedAcyclicGraph::GetMaxEdge() const
{
    return MAXEDGE;
}

/// Number of nodes (`n`).
int DirectedAcyclicGraph::GetNumNodes() const
{
    return n;
}

/// Number of edges requested at construction (`m`).
int DirectedAcyclicGraph::GetNumEdges() const
{
    return m;
}

/// Current value of the static counter incremented by custom_compare().
int DirectedAcyclicGraph::GetComparisonCountSort() const
{
    return comparison_count_sort;
}
